{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perceptron"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris example dataset through seaborn.\n",
    "data = sns.load_dataset(\"iris\")\n",
    "\n",
    "# Map species name -> perceptron target (+1 / -1).\n",
    "class_map = {\n",
    "    'virginica': -1,\n",
    "    'versicolor': 1,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the versicolor and virginica rows (in that order), then add a\n",
    "# numeric 'label' column obtained by looking each species up in class_map.\n",
    "data = pd.concat(\n",
    "    [data[data['species'] == name] for name in ('versicolor', 'virginica')]\n",
    ")\n",
    "data = data.assign(label=data['species'].map(class_map))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix: the four flower measurements, one row per sample.\n",
    "x = data.loc[:, ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].values\n",
    "\n",
    "# Targets taken from the 'label' column, row-aligned with x.\n",
    "labels = data.loc[:, 'label'].values\n",
    "\n",
    "# Number of samples (this exact variable name is used by the training cell).\n",
    "len_datasset = len(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: number of rows in the feature matrix.\n",
    "len(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model parameters, initialised to zero: one weight per feature column of x,\n",
    "# plus a single bias term. Sizing w from x (instead of a literal 4) keeps the\n",
    "# weights consistent with the feature matrix if the selected columns change.\n",
    "w = np.zeros(x.shape[1])\n",
    "b = np.zeros(1)\n",
    "\n",
    "# Training hyperparameters.\n",
    "iter_num = 700        # number of single-sample update steps\n",
    "lr = 0.0001           # learning rate (step size of each update)\n",
    "valid_interval = 100  # report accuracy every this many steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter:99,acc:0.5\n",
      "iter:199,acc:0.95\n",
      "iter:299,acc:0.5\n",
      "iter:399,acc:0.66\n",
      "iter:499,acc:0.5\n",
      "iter:599,acc:0.83\n",
      "iter:699,acc:0.98\n"
     ]
    }
   ],
   "source": [
    "# Seed the RNG so the sequence of sampled rows, and therefore the reported\n",
    "# accuracies, is identical on every run that starts from the same w and b.\n",
    "# Note: w and b are updated from their current values, so re-running this\n",
    "# cell alone continues training instead of starting from scratch.\n",
    "random.seed(0)\n",
    "\n",
    "for i in range(iter_num):\n",
    "    # Stochastic step: pick one sample uniformly at random.\n",
    "    index = random.randrange(len_datasset)\n",
    "    label, xi = labels[index], x[index, :]\n",
    "\n",
    "    # Raw score of the sample under the current weights and bias.\n",
    "    yi = w.dot(xi) + b\n",
    "\n",
    "    # Score and label disagree in sign (or the score is exactly zero):\n",
    "    # apply the perceptron update towards the correct side.\n",
    "    if yi * label <= 0:\n",
    "        w = w + lr * label * xi\n",
    "        b = b + lr * label\n",
    "\n",
    "    # Every valid_interval steps, measure accuracy over the whole dataset.\n",
    "    if (i + 1) % valid_interval == 0:\n",
    "        output = np.sign(x.dot(w) + b)\n",
    "        acc = np.mean(output == labels)\n",
    "        print(f'iter:{i},acc:{round(acc,4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "        1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,\n",
       "        1.,  1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,\n",
       "        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1.,\n",
       "       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n",
       "       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n",
       "       -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n",
       "       -1., -1., -1., -1., -1., -1., -1., -1., -1.])"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted signs from the most recent accuracy evaluation in the training loop.\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
